#include "vanila/baseobject.h"
#include "vanila/dsobject.h"
#include "vanila/object.h"
#include "vanila/disassembler.h"
#include "vanila/chunk.h"
#include "utils/format.h"
#include <string>
#include <iostream>

namespace vanila
{
//! \brief disassemble the bytecode at offset in the chunk and print it to stdout
//! \param[in] chunk the chunk holding the code, the line table and the constants
//! \param[in] offset the offset of the bytecode in the code vector
//! \return size_t the offset of the next bytecode (offset + this instruction's width)
size_t Disassembler::disassembleInstruction(const Chunk& chunk, size_t offset) noexcept
{
    // utils::format is printf-style: %d expects an int, so never pass size_t directly
    std::cout << utils::format("%04d ", static_cast<int>(offset));

    // consecutive instructions from the same source line show a '|' marker instead
    if (offset > 0 && chunk.lines()[offset] == chunk.lines()[offset - 1])
        std::cout << "   | ";
    else
        std::cout << utils::format("%4d ", chunk.lines()[offset]);

    OpCode instruction = chunk.readBytecode(offset);
    switch (instruction)
    {
    case OpCode::CONSTANT:
        return Self::constantInstruction("OP_CONSTANT", chunk, offset);
    case OpCode::NIL:
        return Self::simpleInstruction("OP_NIL", offset);
    case OpCode::TRUE:
        return Self::simpleInstruction("OP_TRUE", offset);
    case OpCode::FALSE:
        return Self::simpleInstruction("OP_FALSE", offset);
    case OpCode::POP:
        return Self::simpleInstruction("OP_POP", offset);
    case OpCode::GET_LOCAL:
        return Self::byteInstruction("OP_GET_LOCAL", chunk, offset);
    case OpCode::SET_LOCAL:
        return Self::byteInstruction("OP_SET_LOCAL", chunk, offset);
    case OpCode::GET_GLOBAL:
        return Self::constantInstruction("OP_GET_GLOBAL", chunk, offset);
    case OpCode::DEFINE_GLOBAL:
        return Self::constantInstruction("OP_DEFINE_GLOBAL", chunk, offset);
    case OpCode::SET_GLOBAL:
        return Self::constantInstruction("OP_SET_GLOBAL", chunk, offset);
    case OpCode::GET_UPVALUE:
        return Self::byteInstruction("OP_GET_UPVALUE", chunk, offset);
    case OpCode::SET_UPVALUE:
        return Self::byteInstruction("OP_SET_UPVALUE", chunk, offset);
    case OpCode::GET_PROPERTY:
        return Self::constantInstruction("OP_GET_PROPERTY", chunk, offset);
    case OpCode::SET_PROPERTY:
        return Self::constantInstruction("OP_SET_PROPERTY", chunk, offset);
    case OpCode::GET_SUPER:
        return Self::constantInstruction("OP_GET_SUPER", chunk, offset);
    case OpCode::EQUAL:
        return Self::simpleInstruction("OP_EQUAL", offset);
    case OpCode::GREATER:
        return Self::simpleInstruction("OP_GREATER", offset);
    case OpCode::LESS:
        return Self::simpleInstruction("OP_LESS", offset);
    case OpCode::ADD:
        return Self::simpleInstruction("OP_ADD", offset);
    case OpCode::SUBTRACT:
        return Self::simpleInstruction("OP_SUBTRACT", offset);
    case OpCode::MULTIPLY:
        return Self::simpleInstruction("OP_MULTIPLY", offset);
    case OpCode::DIVIDE:
        return Self::simpleInstruction("OP_DIVIDE", offset);
    case OpCode::NOT:
        return Self::simpleInstruction("OP_NOT", offset);
    case OpCode::NEGATE:
        return Self::simpleInstruction("OP_NEGATE", offset);
    case OpCode::JUMP:
        return Self::jumpInstruction("OP_JUMP", 1, chunk, offset);
    case OpCode::JUMP_IF_FALSE:
        return Self::jumpInstruction("OP_JUMP_IF_FALSE", 1, chunk, offset);
    case OpCode::LOOP:
        // loops jump backwards, hence the negative sign
        return Self::jumpInstruction("OP_LOOP", -1, chunk, offset);
    case OpCode::CALL:
        return Self::byteInstruction("OP_CALL", chunk, offset);
    case OpCode::INVOKE:
        return Self::invokeInstruction("OP_INVOKE", chunk, offset);
    case OpCode::SUPER_INVOKE:
        return Self::invokeInstruction("OP_SUPER_INVOKE", chunk, offset);
    case OpCode::CLOSURE:
    {
        // layout: opcode, function constant index, then one (isLocal, index)
        // byte pair for each of the function's upvalues
        offset++;
        uint8_t constant = chunk.readByte(offset++);

        std::cout << utils::format("%-16s %4d ", "OP_CLOSURE", constant);
        chunk.getConstant(constant).print();
        std::cout << '\n';

        ObjectFunction* function = chunk.constants()[constant].asFunction();
        for (uint32_t j = 0; j < function->upvalueCount(); ++j)
        {
            // remember where this operand pair starts so it can be printed as its offset
            const size_t pairOffset = offset;
            int isLocal = static_cast<int>(chunk.readByte(offset++));
            int index = static_cast<int>(chunk.readByte(offset++));

            std::cout << utils::format("%04d      |                     %s %d\n",
                static_cast<int>(pairOffset), isLocal ? "local" : "upvalue", index);
        }

        return offset;
    }
    case OpCode::CLOSE_UPVALUE:
        return Self::simpleInstruction("OP_CLOSE_UPVALUE", offset);
    case OpCode::RETURN:
        return Self::simpleInstruction("OP_RETURN", offset);
    case OpCode::CLASS:
        return Self::constantInstruction("OP_CLASS", chunk, offset);
    case OpCode::INHERIT:
        return Self::simpleInstruction("OP_INHERIT", offset);
    case OpCode::METHOD:
        return Self::constantInstruction("OP_METHOD", chunk, offset);
    case OpCode::FIELD:
        return Self::constantInstruction("OP_FIELD", chunk, offset);
    case OpCode::LIST:
        return Self::shortInstruction("OP_LIST", chunk, offset);
    case OpCode::DICT:
        return Self::shortInstruction("OP_DICT", chunk, offset);
    case OpCode::SET:
        return Self::shortInstruction("OP_SET", chunk, offset);
    case OpCode::GET_INDEX:
        return Self::simpleInstruction("OP_GET_INDEX", offset);
    case OpCode::SET_INDEX:
        return Self::simpleInstruction("OP_SET_INDEX", offset);
    default:
        // cast to int: streaming a uint8_t would print it as a character, not a number
        std::cout << "Unknown opcode " << static_cast<int>(instruction) << '\n';
        return offset + 1;
    }
}

//! \brief disassemble every bytecode in the chunk, in order, printing to stdout
//! \param[in] chunk the chunk to disassemble
//! \param[in] name task name, printed first as a header; null prints "<unnamed>"
void Disassembler::disassembleChunk(const Chunk& chunk, const char* name) noexcept
{
    // streaming a null const char* is undefined behaviour, so substitute a placeholder
    std::cout << "== " << (name != nullptr ? name : "<unnamed>") << " ==\n";

    // keep the cursor in size_t, the type disassembleInstruction takes and returns,
    // to avoid narrowing and a signed/unsigned comparison against chunk.size()
    const size_t codeSize = chunk.size();
    for (size_t offset = 0; offset < codeSize;)
        offset = Self::disassembleInstruction(chunk, offset);
}

//! \brief disassemble an instruction that carries no operands
//! \param[in] name the bytecode's name, printed left-aligned in a 16-column field
//! \param[in] offset the bytecode's offset
//! \return size_t the offset of the following bytecode
size_t Disassembler::simpleInstruction(const char* name, size_t offset) noexcept
{
    // the opcode byte is the whole instruction
    const size_t next = offset + 1;

    std::cout << utils::format("%-16s", name);
    std::cout << '\n';

    return next;
}

//! \brief disassemble an instruction whose single operand is a constant-table index
//! \param[in] name the bytecode's name
//! \param[in] chunk the chunk holding both the code and the constants
//! \param[in] offset the bytecode's offset
//! \return size_t the offset of the following bytecode (opcode + 1-byte index)
size_t Disassembler::constantInstruction(const char* name, const Chunk& chunk, int offset) noexcept
{
    // the index byte immediately follows the opcode
    const int operandAt = offset + 1;
    const uint8_t index = chunk.readByte(operandAt);

    // mnemonic and index first, then the constant value wrapped in single quotes
    std::cout << utils::format("%-16s %4d '", name, index);
    chunk.getConstant(index).print();
    std::cout << "'\n";

    return operandAt + 1;
}

//! \brief disassemble an instruction with a single one-byte operand
//! \param[in] name the bytecode's name
//! \param[in] chunk the chunk containing the bytecode
//! \param[in] offset the bytecode's offset
//! \return size_t the offset of the following bytecode (opcode + 1 operand byte)
size_t Disassembler::byteInstruction(const char* name, const Chunk& chunk, int offset) noexcept
{
    // the operand byte sits right after the opcode and is printed as a number
    const int operandAt = offset + 1;
    const uint8_t operand = chunk.readByte(operandAt);

    std::cout << utils::format("%-16s %4d\n", name, operand);

    return operandAt + 1;
}

//! \brief disassemble an instruction with a single two-byte operand
//! \param[in] name the bytecode's name
//! \param[in] chunk the chunk containing the bytecode
//! \param[in] offset the bytecode's offset
//! \return size_t the offset of the following bytecode (opcode + 2 operand bytes)
size_t Disassembler::shortInstruction(const char* name, const Chunk& chunk, int offset) noexcept
{
    // the operand is decoded by Chunk::readShort, starting right after the opcode
    const uint16_t operand = chunk.readShort(offset + 1);
    const int width = 3;

    std::cout << utils::format("%-16s %4d\n", name, operand);
    return offset + width;
}

//! \brief disassemble a jump instruction and print its resolved target offset
//! \param[in] name the bytecode's name
//! \param[in] sign direction of the jump: 1 for forward jumps, -1 for backward (loop) jumps
//! \param[in] chunk the chunk containing the bytecode
//! \param[in] offset the bytecode's offset
//! \return size_t the offset of the following bytecode (opcode + 2 operand bytes)
size_t Disassembler::jumpInstruction(const char* name, int sign, const Chunk& chunk, int offset) noexcept
{
    // the 16-bit distance is stored high byte first, right after the opcode
    const uint16_t high = chunk.readByte(offset + 1);
    const uint16_t low = chunk.readByte(offset + 2);
    const uint16_t distance = static_cast<uint16_t>((high << 8) | low);

    // the distance is measured from the end of this 3-byte instruction
    const int target = offset + 3 + sign * distance;
    std::cout << utils::format("%-16s %4d -> %d\n", name, offset, target);

    return offset + 3;
}

//! \brief disassemble an invoke instruction (constant index + argument count)
//! \param[in] name the bytecode's name
//! \param[in] chunk the chunk holding both the code and the constants
//! \param[in] offset the bytecode's offset
//! \return size_t the offset of the following bytecode (opcode + 2 operand bytes)
size_t Disassembler::invokeInstruction(const char* name, const Chunk& chunk, int offset) noexcept
{
    // operand bytes: [offset + 1] constant index (presumably the method name — verify
    // against the compiler), [offset + 2] argument count
    const uint8_t argCount = chunk.readByte(offset + 2);
    const uint8_t nameIndex = chunk.readByte(offset + 1);

    std::cout << utils::format("%-16s (%d args) %4d '", name, argCount, nameIndex);
    chunk.getConstant(nameIndex).print();
    std::cout << "'\n";

    const int width = 3;
    return offset + width;
}

}